"""
 @Author: zzgsg
 @time: 2024/9/28 20:32
"""

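# Figure 7 (saved as temp/Figure7.*): long-term validation of the profiles retrieved by USTC-PRM (with and without LPAP) and OEM in CAMS against GuanYuan in situ, AERONET and tower measurements.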

import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

def hide_spines(ax):
    """Blank out an unused subplot by hiding its frame and all of its ticks.

    All four spines are made invisible. The tick-label size is set to 8 and
    the major tick length to 0. Finally both tick lists are emptied, so the
    axes shows neither tick marks nor tick labels.
    """
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    # Tick styling is applied first; the tick lists are cleared afterwards.
    ax.tick_params(axis='both', which='both', labelsize=8)
    ax.tick_params(axis='both', which='major', length=0)
    ax.set_xticks([])
    ax.set_yticks([])

def ax_plot(ax, x, y, xlabel, ylabel, xlim, ylim, rstr='R={:.3f}', linestr='y={:.2f}x+{:.2f}', order='(a)', LPAP=0):
    """Draw a density-coloured scatter comparison of ``y`` against ``x`` on ``ax``.

    Pairs where either value is NaN or infinite are dropped. Each remaining
    point is coloured by a Gaussian-KDE estimate of the joint (x, y) density
    using the 'Spectral_r' colormap. A least-squares line
    ``y = slope * x + intercept`` is drawn as a red dashed line. The fit
    equation, the Pearson correlation coefficient and the panel label are
    annotated in axes-fraction coordinates.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; it is also made the current pyplot axes.
    x, y : numpy arrays
        Paired samples of the same shape (x on the horizontal axis).
    xlabel, ylabel : str
        Axis labels.
    xlim, ylim : sequence of two numbers
        Axis limits.
    rstr : str
        Format string receiving the correlation coefficient.
    linestr : str
        Format string receiving ``(slope, intercept)``.
    order : str
        Panel label drawn near the upper-left corner, e.g. '(a)'.
    LPAP : int
        Currently unused; kept so existing callers that pass it keep working.
    """
    # Drop any pair with a missing or non-finite member; the KDE cannot handle them.
    valid = np.isfinite(x) & np.isfinite(y)
    x, y = x[valid], y[valid]

    xy = np.vstack([x, y])
    z = gaussian_kde(xy)(xy)
    # Draw the densest points last so they are not hidden beneath sparse ones.
    idx = z.argsort()
    x, y, z = x[idx], y[idx], z[idx]

    # Fit once and reuse the coefficients for both the line and its label.
    slope, intercept = np.polyfit(x, y, 1)
    r = np.corrcoef(x, y)[0, 1]

    plt.sca(ax)  # keep ax as the current pyplot axes for any later plt.* calls
    ax.scatter(x, y, marker='o', c=z, edgecolors=['none'], s=25, label='LST',
               cmap='Spectral_r')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(linestyle='-.')
    # axline is infinite, so the fit spans the whole visible range whatever the limits.
    ax.axline(xy1=(0, intercept), xy2=(1, slope + intercept), linestyle='--', color='r')
    ax.annotate(linestr.format(slope, intercept), xy=(0.8, 0.2), xycoords='axes fraction',
                horizontalalignment='center', verticalalignment='top')
    ax.annotate(rstr.format(r), xy=(0.8, 0.1), xycoords='axes fraction',
                horizontalalignment='center', verticalalignment='top')
    ax.annotate(order, xy=(0.1, 0.9), xycoords='axes fraction', fontsize=20,
                horizontalalignment='center', verticalalignment='top')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

def axs_deal(axs, data_dir='Long_term_data'):
    """Draw every validation scatter panel of the figure onto ``axs``.

    ``axs`` is the flattened 5x5 axes grid; each figure row occupies five
    consecutive entries:

    * row 0 - USTC-PRM ("oppa") aerosol, 5 panels
    * row 1 - OEM ("oe") aerosol, 5 panels
    * row 2 - USTC-PRM NO2, 4 panels
    * row 3 - USTC-PRM with LPAP NO2, 4 panels
    * row 4 - OEM NO2, 4 panels

    The fifth axis of each NO2 row is left untouched. Panels are labelled
    (a), (b), ... in the order listed above.

    Parameters
    ----------
    axs : sequence of matplotlib Axes
        At least 24 axes; indices 0-23 are used, 14 and 19 are skipped.
    data_dir : str, optional
        Directory holding the ``*.npy`` comparison arrays. In ``*_vs_gk`` and
        ``*_vs_aeronet`` files, row 0 is plotted on y (retrieval) and row 1
        on x (reference). In ``*_vs_tower`` files, rows 0-2 are plotted on y
        (0-100/100-200/200-300 m layers) against rows 3-5 on x (tower at
        60/160/280 m).
    """
    letters = 'abcdefghijklmnopqrstuvwxyz'
    pm_unit = '[${10}^{2}$ μg/$m^{3}$]'
    no2_unit = '[μg/$m^{3}$]'
    tower_heights = ('60m', '160m', '280m')
    layers = ('0-100m', '100-200m', '200-300m')
    no2_tower_limits = (80, 70, 60)

    # Each spec: (axis index, file stem, x row, y row, x divisor,
    #             xlabel, ylabel, upper axis limit, LPAP flag)
    specs = []

    # Aerosol rows: PM2.5 is shown in units of 100 μg/m^3, hence the divisor 100.
    for row, retrieval in enumerate(('oppa', 'oe')):
        first = 5 * row
        specs.append((first, 'extinction_%s_vs_gk' % retrieval, 1, 0, 100,
                      'GuanYuan in situ PM2.5 ' + pm_unit, 'Extinction @360nm [1/km]', 3, 0))
        # AOD is a dimensionless optical depth, so its labels carry no unit.
        specs.append((first + 1, 'aod_%s_vs_aeronet' % retrieval, 1, 0, 1,
                      'AERONET AOD', 'AOD @360nm', 3, 0))
        for k in range(3):
            specs.append((first + 2 + k, 'extinction_%s_vs_tower' % retrieval, 3 + k, k, 100,
                          'Tower %s PM2.5 %s' % (tower_heights[k], pm_unit),
                          '%s Extinction @360nm [1/km]' % layers[k], 2, 0))

    # NO2 rows: only four panels each, starting at the row's first axis.
    for row, (retrieval, lpap) in enumerate((('oppa', 0), ('oppa_lpap', 1), ('oe', 0)), start=2):
        first = 5 * row
        specs.append((first, 'no2_%s_vs_gk' % retrieval, 1, 0, 1,
                      'GuanYuan in situ $NO_2$ ' + no2_unit, '$NO_2$ ' + no2_unit, 140, lpap))
        for k in range(3):
            specs.append((first + 1 + k, 'no2_%s_vs_tower' % retrieval, 3 + k, k, 1,
                          'Tower %s $NO_2$ %s' % (tower_heights[k], no2_unit),
                          '%s $NO_2$ %s' % (layers[k], no2_unit), no2_tower_limits[k], lpap))

    loaded = {}  # file stem -> array, so each file is read only once
    for letter, (ax_index, stem, x_row, y_row, x_div, xlabel, ylabel, upper, lpap) in zip(letters, specs):
        if stem not in loaded:
            loaded[stem] = np.load('%s/%s.npy' % (data_dir, stem))
        data = loaded[stem]
        x = data[x_row] / x_div if x_div != 1 else data[x_row]
        ax_plot(axs[ax_index], x, data[y_row], xlabel, ylabel, [0, upper], [0, upper],
                order='(%s)' % letter, LPAP=lpap)


plt.rcParams['mathtext.default'] = 'default'
plt.rcParams['font.size'] = 12
fig, axs = plt.subplots(5, 5, figsize=(24, 20), dpi=90)
fig.subplots_adjust(top=0.97, bottom=0.04, right=0.97, left=0.10, hspace=0.2, wspace=0.22)
axs_deal(axs.flatten())

# Row titles, written vertically to the left of the first column.
row_titles = (
    'USTC-PRM\nAEROSOL',
    'OEM\nAEROSOL',
    'USTC-PRM\n$NO_2$',
    'USTC-PRM with LPAP\n$NO_2$',
    'OEM\n$NO_2$',
)
for row_ax, row_title in zip(axs[:, 0], row_titles):
    row_ax.annotate(row_title, xy=(-0.3, 0.5), xycoords='axes fraction', rotation=90, fontsize=20,
                    horizontalalignment='center', verticalalignment='center')

# The NO2 rows (2-4) have only four panels; blank out their unused last column.
for row in (2, 3, 4):
    hide_spines(axs[row, -1])

# Shared colorbar. Every scatter autoscales its own density colours, so the bar
# only marks the relative "Low"/"High" ends rather than absolute values.
cax = fig.add_axes([0.825, 0.1, 0.15, 0.015])
density_mappable = matplotlib.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=0, vmax=1),
                                                cmap='Spectral_r')
cbar = fig.colorbar(density_mappable, cax=cax, ticks=[0.03, 0.97], orientation='horizontal')
cbar.set_label(label='Probability density', fontsize=20)
cbar.ax.set_xticklabels(['Low', 'High'])  # label the two ends of the colorbar

# savefig does not create missing directories.
os.makedirs('temp', exist_ok=True)
fig.savefig('temp/Figure7.png')
fig.savefig('temp/Figure7.eps')
plt.close(fig)
